{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Classification with Encrypted Neural Networks\n",
    "\n",
    "In this tutorial, we'll look at how we can achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, suppose say Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We will see how CrypTen allows Alice and Bob to coordinate and classify the data, while achieving their privacy requirements.\n",
    "\n",
    "To simulate this scenario, we will begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.\n",
    "\n",
    "## Setup\n",
    "\n",
    "We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:module 'torchvision.models.mobilenet' has no attribute 'ConvBNReLU'\n"
     ]
    }
   ],
   "source": [
    "import crypten\n",
    "import torch\n",
    "\n",
    "crypten.init()\n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run script that downloads the publicly available MNIST data, and splits the data as required.\n",
    "%run ./mnist_utils.py --option train_v_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will define the structure of Alice's network as a class. Even though Alice has a pre-trained model, the CrypTen will require this structure as input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Alice's network\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class AliceNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AliceNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 128)\n",
    "        self.fc2 = nn.Linear(128, 128)\n",
    "        self.fc3 = nn.Linear(128, 10)\n",
    " \n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc3(out)\n",
    "        return out\n",
    "    \n",
    "crypten.common.serial.register_safe_class(AliceNet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also define a helper routine `compute_accuracy` to make it easy to compute the accuracy of the output we get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(output, labels):\n",
    "    pred = output.argmax(1)\n",
    "    correct = pred.eq(labels)\n",
    "    correct_count = correct.sum(0, keepdim=True).float()\n",
    "    accuracy = correct_count.mul_(100.0 / output.size(0))\n",
    "    return accuracy"
   ]
  },
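  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of what this helper computes, the following minimal sketch reimplements the same logic in plain Python, without `torch`. The name `accuracy_percent` and the toy scores and labels are made up for illustration and are not part of CrypTen or `mnist_utils.py`. Each row of scores is reduced to the index of its largest entry, the predictions are compared with the labels, and the fraction of matches is reported as a percentage. With the toy data below, two of the three predictions match their labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical plain-Python equivalent of compute_accuracy (illustration only)\n",
    "def accuracy_percent(scores, labels):\n",
    "    # argmax over each row of class scores\n",
    "    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]\n",
    "    # count the predictions that match the true labels\n",
    "    correct = sum(int(p == y) for p, y in zip(preds, labels))\n",
    "    return 100.0 * correct / len(scores)\n",
    "\n",
    "# Toy data: 3 samples, 2 classes; the third prediction (class 1) is wrong\n",
    "toy_scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]\n",
    "toy_labels = [1, 0, 0]\n",
    "print(\"Accuracy: {0:.4f}\".format(accuracy_percent(toy_scores, toy_labels)))"
   ]
  },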
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encrypting a Pre-trained Model\n",
    "\n",
    "Assume that Alice has a pre-trained network ready to classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. We'll use the pre-trained model in `models/tutorial4_alice_model.pth` in this tutorial. As in Tutorial 3, we will assume Alice is using the rank 0 process, while Bob is using the rank 1 process. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ALICE = 0\n",
    "BOB = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In CrypTen, encrypting PyTorch network is straightforward: we load a PyTorch model from file to the appropriate source, convert it to a CrypTen model and then encrypt it. Let us understand each of these steps.\n",
    "\n",
    "As we did with CrypTensors in Tutorial 3, we will use CrypTen's load functionality (i.e., `crypten.load`) to read a model from file to a particular source. The source is indicated by the keyword argument `src`. As in Tutorial 3, this src argument tells us the rank of the party we want to load the model to (and later, encrypt the model from). In addition, here we also need to provide a dummy model to tell CrypTen the model's structure. The dummy model is indicated by the keyword argument `dummy_model`. Note that unlike loading a tensor, the result from `crypten.load` is not encrypted. Instead, only the `src` party's model is populated from the file.\n",
    "\n",
    "Once the model is loaded, we call the function `from_pytorch`: this function sets up a CrypTen network from the PyTorch network. It takes the plaintext network as input as well as dummy input. The dummy input must be a `torch` tensor of the same shape as a potential input to the network, however the values inside the tensor do not matter.  \n",
    "\n",
    "Finally, we call `encrypt` on the CrypTen network to encrypt its parameters. Once we call the `encrypt` function, the models `encrypted` property will verify that the model parameters have been encrypted. (Encrypted CrypTen networks can also be decrypted using the `decrypt` function)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AliceNet(\n",
      "  (fc1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (fc2): Linear(in_features=128, out_features=128, bias=True)\n",
      "  (fc3): Linear(in_features=128, out_features=10, bias=True)\n",
      ")\n",
      "Model successfully encrypted: True\n"
     ]
    }
   ],
   "source": [
    "# Load pre-trained model to Alice\n",
    "dummy_model = AliceNet()\n",
    "plaintext_model = torch.load('models/tutorial4_alice_model.pth')\n",
    "\n",
    "print(plaintext_model)\n",
    "\n",
    "# Encrypt the model from Alice:    \n",
    "\n",
    "# 1. Create a dummy input with the same shape as the model input\n",
    "dummy_input = torch.empty((1, 784))\n",
    "\n",
    "# 2. Construct a CrypTen network with the trained model and dummy_input\n",
    "private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)\n",
    "\n",
    "# 3. Encrypt the CrypTen network with src=ALICE\n",
    "private_model.encrypt(src=ALICE)\n",
    "\n",
    "#Check that model is encrypted:\n",
    "print(\"Model successfully encrypted:\", private_model.encrypted)"
   ]
  },
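  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remark that encrypted networks can also be decrypted can be tried on a much smaller network. The sketch below is an illustration under assumptions rather than part of Alice's workflow: the two-layer `tiny_model` and its input size are invented for the example, and it runs in a single process just like the cell above. It encrypts the converted network, records the `encrypted` property, decrypts the network, and records the property again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import crypten\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Initialize CrypTen only if an earlier cell has not already done so\n",
    "if not crypten.is_initialized():\n",
    "    crypten.init()\n",
    "\n",
    "# Hypothetical toy network, used only to illustrate encrypt/decrypt\n",
    "tiny_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))\n",
    "tiny_private = crypten.nn.from_pytorch(tiny_model, torch.empty((1, 4)))\n",
    "\n",
    "# Encrypt the parameters and record the flag\n",
    "tiny_private.encrypt(src=0)\n",
    "state_after_encrypt = tiny_private.encrypted\n",
    "print(\"Encrypted after encrypt():\", state_after_encrypt)\n",
    "\n",
    "# Decrypt the parameters again and record the flag\n",
    "tiny_private.decrypt()\n",
    "state_after_decrypt = tiny_private.encrypted\n",
    "print(\"Encrypted after decrypt():\", state_after_decrypt)"
   ]
  },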
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifying Encrypted Data with Encrypted Model\n",
    "\n",
    "We can now use Alice's encrypted network to classify Bob's data. For this, we need to encrypt Bob's data as well, as we did in Tutorial 3 (recall that Bob has the rank 1 process). Once Alice's network and Bob's data are both encrypted, CrypTen inference is performed with essentially identical steps as in PyTorch. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tAccuracy: 99.0000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import crypten.mpc as mpc\n",
    "import crypten.communicator as comm\n",
    "\n",
    "labels = torch.load('/tmp/bob_test_labels.pth').long()\n",
    "count = 100 # For illustration purposes, we'll use only 100 samples for classification\n",
    "\n",
    "@mpc.run_multiprocess(world_size=2)\n",
    "def encrypt_model_and_data():\n",
    "    # Load pre-trained model to Alice\n",
    "    model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)\n",
    "    \n",
    "    # Encrypt model from Alice \n",
    "    dummy_input = torch.empty((1, 784))\n",
    "    private_model = crypten.nn.from_pytorch(model, dummy_input)\n",
    "    private_model.encrypt(src=ALICE)\n",
    "    \n",
    "    # Load data to Bob\n",
    "    data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=BOB)\n",
    "    data_enc2 = data_enc[:count]\n",
    "    data_flatten = data_enc2.flatten(start_dim=1)\n",
    "\n",
    "    # Classify the encrypted data\n",
    "    private_model.eval()\n",
    "    output_enc = private_model(data_flatten)\n",
    "    \n",
    "    # Compute the accuracy\n",
    "    output = output_enc.get_plain_text()\n",
    "    accuracy = compute_accuracy(output, labels[:count])\n",
    "    crypten.print(\"\\tAccuracy: {0:.4f}\".format(accuracy.item()))\n",
    "    \n",
    "encrypt_model_and_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validating Encrypted Classification\n",
    "\n",
    "Finally, we will verify that CrypTen classification results in encrypted output, and that this output can be decrypted into meaningful labels. \n",
    "\n",
    "To see this, in this tutorial, we will just check whether the result is an encrypted tensor; in the next tutorial, we will look into the values of tensor and confirm the encryption. We will also decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted output of the model, and can both use this to obtain the labels. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output tensor encrypted: True\n",
      "Decrypted labels:\n",
      " tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,\n",
      "        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,\n",
      "        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0,\n",
      "        2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4,\n",
      "        1, 7, 6, 9])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def encrypt_model_and_data():\n",
    "    # Load pre-trained model to Alice\n",
    "    plaintext_model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)\n",
    "    \n",
    "    # Encrypt model from Alice \n",
    "    dummy_input = torch.empty((1, 784))\n",
    "    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)\n",
    "    private_model.encrypt(src=ALICE)\n",
    "    \n",
    "    # Load data to Bob\n",
    "    data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=BOB)\n",
    "    data_enc2 = data_enc[:count]\n",
    "    data_flatten = data_enc2.flatten(start_dim=1)\n",
    "\n",
    "    # Classify the encrypted data\n",
    "    private_model.eval()\n",
    "    output_enc = private_model(data_flatten)\n",
    "    \n",
    "    # Verify the results are encrypted: \n",
    "    crypten.print(\"Output tensor encrypted:\", crypten.is_encrypted_tensor(output_enc)) \n",
    "\n",
    "    # Decrypting the result\n",
    "    output = output_enc.get_plain_text()\n",
    "\n",
    "    # Obtaining the labels\n",
    "    pred = output.argmax(dim=1)\n",
    "    crypten.print(\"Decrypted labels:\\n\", pred)\n",
    "    \n",
    "encrypt_model_and_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "This completes our tutorial. While we have used a simple network here to illustrate the concepts, CrypTen provides primitives to allow for encryption of substantially more complex networks. In our examples section, we demonstrate how CrypTen can be used to encrypt LeNet and ResNet, among others. \n",
    "\n",
    "Before exiting this tutorial, please clean up the files generated using the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "filenames = ['/tmp/alice_train.pth', \n",
    "             '/tmp/alice_train_labels.pth', \n",
    "             '/tmp/bob_test.pth', \n",
    "             '/tmp/bob_test_labels.pth']\n",
    "\n",
    "for fn in filenames:\n",
    "    if os.path.exists(fn): os.remove(fn)"
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "disseminate_notebook_id": {
   "notebook_id": "390894444956881"
  },
  "disseminate_notebook_info": {
   "bento_version": "20190826-030256",
   "description": "",
   "hide_code": false,
   "hipster_group": "",
   "kernel_build_info": {
    "error": "The file located at '/data/users/shobha/fbsource/fbcode/bento/kernels/local/cryptenk/TARGETS' could not be found."
   },
   "no_uii": true,
   "notebook_number": "139932",
   "others_can_edit": true,
   "reviewers": "",
   "revision_id": "375902760006757",
   "tags": "",
   "tasks": "",
   "title": "Tutorial 4 -- Classification with Encrypted Neural Networks"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
